import logging
import random
from typing import List
import torch
from torch import nn, Tensor
from mmcv.runner import load_state_dict
# from mmcls.utils import get_root_logger

from ..builder import NECKS, build_loss, build_neck
from ..utils import top_pool
from ..vit.layers import resize_pos_embed

@NECKS.register_module()
class SUMAggregator(nn.Module):
    """Collapse dim 1 of a tensor into a single-channel map by summation.

    Every element is optionally replaced by its absolute value, raised to
    ``power``, and the results are summed over dim 1 (kept with size 1).

    Args:
        power: Exponent applied element-wise before the summation.
        abs: When truthy, take the absolute value before exponentiation.
    """

    def __init__(self, power=1, abs=True):
        super().__init__()
        self.abs = abs
        self.power = power

    def forward(self, x):
        """Return ``sum(x ** power)`` over dim 1, using ``|x|`` if ``abs`` is set."""
        magnitude = x.abs() if self.abs else x
        powered = magnitude.pow(self.power)
        return powered.sum(dim=1, keepdim=True)




@NECKS.register_module()
class CNNAttentionPooling(nn.Module):
    """Turn a feature map into tokens and optionally keep a pooled subset.

    Process for every input:
        1. Generate an attention map from the raw feature map with the
           module built from ``method``.
        2. Project the feature map to ``out_channel`` channels (1x1 conv).
        3. Flatten the spatial axes into a token axis and add ``pos_embeds``.
        4. If ``pool_config`` is set, keep, per sample, the tokens selected
           by ``top_pool`` on the flattened attention map.

    Args:
        in_channel: Number of channels of the incoming feature map.
        out_channel: Number of channels produced by the projection.
        method: Config handed to ``build_neck`` to create the attention
            module.
        patch_num: Length of the token axis of ``pos_embeds``.
        pool_config: Keyword arguments forwarded to ``top_pool``. ``None``
            disables pooling.
        pretrained: Optional checkpoint path. When given, ``init_weights``
            is called from the constructor.
    """

    def __init__(
        self,
        in_channel: int,
        out_channel: int,
        method=dict(type='SUMAggregator'),
        patch_num=0,
        pool_config: dict = None,
        pretrained=None,
    ):
        super().__init__()
        self.attn_f = build_neck(method)
        self.proj = nn.Conv2d(in_channel, out_channel, kernel_size=1)
        # NOTE(review): with the default patch_num=0 this parameter has an
        # empty token axis. ``forward`` adds it to the flattened feature map,
        # which only broadcasts when patch_num matches the number of spatial
        # positions (or is 1) -- confirm callers always set patch_num.
        self.pos_embeds = nn.Parameter(
            torch.zeros(1, patch_num, out_channel), requires_grad=True)
        self.patch_num = patch_num
        self.pool_config = pool_config
        if pretrained is not None:
            self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        """Initialise ``pos_embeds`` from the ``pos_embed`` entry of a checkpoint.

        The checkpoint's ``pos_embed`` is sliced with ``[:, 1:]`` (its first
        token is dropped -- presumably a class token, TODO confirm) and is
        resized with ``resize_pos_embed`` when its token count differs from
        ``pos_embeds``.

        Args:
            pretrained: Path of the checkpoint to read. ``None`` is a no-op.

        Raises:
            KeyError: If the checkpoint has no ``pos_embed`` entry.
        """
        if pretrained is None:
            return
        # ``get_root_logger`` is not imported in this module, so calling it
        # raised NameError; use the standard-library logger instead.
        logger = logging.getLogger(__name__)
        logger.warning(f'{self.__class__.__name__} load pretrain from {pretrained}')
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(pretrained, map_location='cpu')
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        pos_embed = state_dict['pos_embed'][:, 1:, ...]     # [1, 197, 768] for small

        patch_num_new = self.pos_embeds.shape[1]
        if patch_num_new != pos_embed.shape[1]:
            logger.warning(f'interpolate pos_embed from {pos_embed.shape[1]} to {patch_num_new}')
            pos_embed_new = resize_pos_embed(pos_embed, self.pos_embeds.shape, 0)
        else:
            pos_embed_new = pos_embed
        # strict=False: only ``pos_embeds`` is supplied, every other key of
        # this module is intentionally left untouched.
        load_state_dict(self, dict(pos_embeds=pos_embed_new), strict=False, logger=logger)

    def forward(self, input: List[Tensor]):
        """Project, embed and optionally pool a single feature map.

        Args:
            input: A sequence holding exactly one feature-map tensor. It is
                flattened from dim 2 onwards, so it needs at least 3 dims.

        Returns:
            dict: ``x`` holds the projected tokens (pooled when
            ``pool_config`` is set and ``top_pool`` returns indexes), and
            ``attn`` holds the unflattened attention map.
        """
        assert len(input) == 1
        x = input[0]
        # Attention is computed on the raw features, before projection.
        attn2D = self.attn_f(x)
        x = self.proj(x)
        # Merge dims 2.. into one token axis, then move channels last.
        x = x.flatten(start_dim=2).transpose(1, 2)
        x = x + self.pos_embeds

        if self.pool_config is None:
            return dict(x=x, attn=attn2D)

        attn = attn2D.flatten(start_dim=2).transpose(1, 2)
        keep_indexes = top_pool(attn, **self.pool_config)
        if keep_indexes is None:
            return dict(x=x, attn=attn2D)

        # Gather the kept tokens sample by sample; torch.stack requires every
        # sample to keep the same number of tokens.
        kept = [x[j, keep_indexes[j], ...] for j in range(x.shape[0])]
        return dict(x=torch.stack(kept), attn=attn2D)
    

